package com.sunsy.authserver.config;

import com.baomidou.mybatisplus.core.metadata.OrderItem;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.MethodParameter;
import org.springframework.web.bind.support.WebDataBinderFactory;
import org.springframework.web.context.request.NativeWebRequest;
import org.springframework.web.method.support.HandlerMethodArgumentResolver;
import org.springframework.web.method.support.ModelAndViewContainer;

import javax.servlet.http.HttpServletRequest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * mybatis-plus order注入
 * @author 孙思远
 * @date 2021/7/19
 */
public class SqlFilterArgumentResolver implements HandlerMethodArgumentResolver {

  private final static String[] KEYWORDS = {"master", "truncate", "insert", "select"
      , "delete", "update", "declare", "alter", "drop", "sleep"};

  /**
   * 判断Controller是否包含page 参数
   *
   * @param parameter 参数
   * @return 是否过滤
   */
  @Override
  public boolean supportsParameter(MethodParameter parameter) {
    return parameter.getParameterType().equals(Page.class);
  }

  /**
   * @param parameter     入参集合
   * @param mavContainer  model 和 view
   * @param webRequest    web相关
   * @param binderFactory 入参解析
   * @return 检查后新的page对象
   * <p>
   * page 只支持查询 GET .如需解析POST获取请求报文体处理
   */
  @Override
  public Object resolveArgument(MethodParameter parameter, ModelAndViewContainer mavContainer
      , NativeWebRequest webRequest, WebDataBinderFactory binderFactory) {

    HttpServletRequest request = webRequest.getNativeRequest(HttpServletRequest.class);

    String[] ascs = request.getParameterValues("ascs");
    String[] descs = request.getParameterValues("descs");
    String current = request.getParameter("current");
    String size = request.getParameter("size");

    Page page = new Page();
    if (StringUtils.isNotBlank(current)) {
      page.setCurrent(Long.parseLong(current));
    }

    if (StringUtils.isNotBlank(size)) {
      page.setSize(Long.parseLong(size));
    }

    List<OrderItem> orderItemList = new ArrayList<>();
    Optional.ofNullable(ascs).ifPresent(s -> orderItemList.addAll(Arrays.stream(s)
        .filter(sqlInjectPredicate())
        .map(OrderItem::asc).collect(Collectors.toList())));
    Optional.ofNullable(descs).ifPresent(s -> orderItemList.addAll(Arrays.stream(s)
        .filter(sqlInjectPredicate())
        .map(OrderItem::desc).collect(Collectors.toList())));
    page.addOrder(orderItemList);

    return page;
  }

  /**
   * 判断用户输入里面有没有关键字
   *
   * @return Predicate
   */
  private Predicate<String> sqlInjectPredicate() {
    return sql -> {
      for (String keyword : KEYWORDS) {
        if (StringUtils.containsIgnoreCase(sql, keyword)) {
          return false;
        }
      }
      return true;
    };
  }

}
